import os
from typing import List, Tuple

import numpy as np
import pandas as pd
from scipy.stats import mode
from sklearn.preprocessing import PowerTransformer, LabelEncoder, power_transform
from concurrent.futures import ProcessPoolExecutor

from statsmodels.tsa.seasonal import STL

import glob

# Directory that save() creates if needed and writes the per-station CSV files into.
save_path = "data/preprocessed"


def main():
    """Run the whole preprocessing pipeline on the CSV files in ``data/raw``.

    Each file is loaded and indexed by a timestamp built from its
    ``hour``/``day``/``month``/``year`` columns. The frames are then passed
    through category encoding, missing-value imputation, feature dropping,
    seasonality removal and power transformation before being saved.

    Raises:
        FileNotFoundError: if ``data/raw`` contains no ``*.csv`` file.
    """
    # load data
    file_path = "data/raw"
    # glob returns paths in arbitrary order; sort so that the frame used for
    # schema inference downstream (the first one) is the same on every run.
    all_files = sorted(glob.glob(os.path.join(file_path, "*.csv")))
    if not all_files:
        # Fail early with an explicit message instead of an obscure error
        # from the first helper that receives an empty list.
        raise FileNotFoundError(f"no CSV files found in {file_path!r}")

    dfs = []
    for file in all_files:
        df = pd.read_csv(file, index_col=None, header=0)
        # drop the No column
        df = df.drop(columns=["No"])
        # lowering the column name
        df.columns = df.columns.str.lower()
        # set date-time index
        df["Datetime"] = pd.to_datetime(df[["hour", "day", "month", "year"]])
        df = df.set_index("Datetime")
        df = df.drop(columns=["hour", "day", "month", "year"])

        dfs.append(df)

    replace_categories(dfs)
    fill_na(dfs)
    # drop unnecessary features
    for df in dfs:
        df.drop(["wd", "pm10"], axis=1, inplace=True)

    dfs_removed_seasonality = remove_seasonality(dfs)
    dfs = power_transform(dfs_removed_seasonality)
    save(dfs)

# replace categorical columns
def replace_categories(dfs: List[pd.DataFrame]):
    """Encode the ``wd`` column of every frame as float category codes, in place.

    The codes are the positions of the labels in the sorted list of all
    distinct non-missing ``wd`` values found across every frame, so the same
    label gets the same code in every frame. Missing values stay NaN.

    Args:
        dfs: frames that each contain a ``wd`` column; the list may be empty,
            in which case nothing happens.
    """
    if not dfs:
        return

    observed = pd.concat(
        [frame["wd"] for frame in dfs], ignore_index=True
    ).dropna()
    # Sorted distinct labels -> consecutive codes (same order a label encoder
    # fitted on these values would use).
    codes = {
        label: float(code)
        for code, label in enumerate(sorted(observed.unique()))
    }

    for frame in dfs:
        # Build the numeric column in one step instead of writing numbers into
        # a string column through .loc, which depends on the column's dtype
        # accepting mixed values. Labels absent from `codes` (NaN) map to NaN.
        frame["wd"] = frame["wd"].map(codes).astype("float64")


def _window_mode(values: np.ndarray) -> float:
    """Return the most frequent non-NaN value of a window (smallest on ties)."""
    valid = values[~np.isnan(values)]
    if valid.size == 0:
        return np.nan
    uniques, counts = np.unique(valid, return_counts=True)
    # np.unique sorts its output and argmax takes the first maximum, so ties
    # resolve to the smallest value.
    return uniques[np.argmax(counts)]


def fill_na(dfs: List[pd.DataFrame]):
    """Fill missing values of every frame in place using centered rolling windows.

    Numerical columns (``float64`` columns of the first frame, excluding the
    categorical ones) are filled with the rolling mean; the categorical column
    ``wd`` is filled with the rolling mode. For each column the window spans
    ``nb_null // 24 + 1`` days, where ``nb_null`` is that column's number of
    missing values.

    Args:
        dfs: frames sharing the same columns; the rolling windows are
            time-based, so a datetime-like index is assumed. An empty list is
            accepted and left untouched.

    Raises:
        AssertionError: if any missing value remains after filling.
    """
    if not dfs:
        return

    categorical_columns = ["wd"]
    # An encoded categorical column is float64 as well; it must not be treated
    # as numerical, otherwise it would be filled with fractional means and the
    # mode-based fill below would never see a missing value.
    numerical_columns = [
        column for column, dtype in dfs[0].dtypes.items()
        if dtype == "float64" and column not in categorical_columns
    ]

    for df in dfs:
        for column in numerical_columns:
            nb_null = df[column].isnull().sum()
            window_size = nb_null // 24 + 1
            rolling_mean = df[column].rolling(
                f"{window_size}D", min_periods=1, center=True
            ).mean()
            df[column] = df[column].fillna(rolling_mean)
        for column in categorical_columns:
            nb_null = df[column].isnull().sum()
            window_size = nb_null // 24 + 1
            # The window around a missing entry contains that NaN itself, so
            # the mode has to ignore NaNs rather than propagate them.
            rolling_mode = df[column].rolling(
                f"{window_size}D", min_periods=1, center=True
            ).apply(_window_mode, raw=True)
            df[column] = df[column].fillna(rolling_mode)

    # assert that all the null values have been replaced
    assert all(
        df[column].isnull().sum() == 0 for df in dfs for column in df.columns
    ), "missing values remain after imputation"


def remove_seasonality_parallelized(
    arg: Tuple[pd.Series, str, str]
) -> pd.DataFrame:
    """Decompose one feature series with STL and return its residual in long form.

    Args:
        arg: a ``(series, feature name, station)`` tuple, packed into a single
            argument so the function can be used with ``executor.map``.

    Returns:
        A frame with columns ``Datetime`` (the series index), ``station``,
        ``feature`` (the ``resid`` component of the fitted decomposition)
        and ``feature_name``.
    """
    series, name, station_id = arg

    # period = 365*24 samples -- presumably one year of hourly observations;
    # confirm against the sampling rate of the input data.
    fitted = STL(series, period=365*24).fit()

    long_form = {
        "Datetime": series.index,
        "station": station_id,
        "feature": fitted.resid,
        "feature_name": name,
    }
    return pd.DataFrame(long_form)


def remove_seasonality(
    dfs: List[pd.DataFrame]
) -> List[pd.DataFrame]:
    """Deseasonalize every non-``station`` column of every frame in a process pool.

    Args:
        dfs: frames that each hold a ``station`` column; the station label is
            read from the first row of each frame.

    Returns:
        One frame per station (as produced by grouping on ``station``),
        indexed by ``Datetime``, with one column per feature plus ``station``.
    """
    # One job per (frame, feature column) pair.
    jobs = []
    for frame in dfs:
        station_id = frame["station"].iloc[0]
        jobs.extend(
            (frame[name], name, station_id)
            for name in frame.columns
            if name != "station"
        )

    with ProcessPoolExecutor(max_workers=32) as pool:
        pieces = list(pool.map(remove_seasonality_parallelized, jobs))

    # Stack the long-form results, then pivot back to one column per feature.
    long_table = pd.concat(pieces, ignore_index=True)
    wide_table = long_table.pivot_table(
        values="feature",
        index=["Datetime", "station"],
        columns="feature_name",
    ).reset_index()
    wide_table.columns.name = None

    # split into list of station-specific dataframe
    per_station = []
    for _, station_rows in wide_table.groupby("station"):
        per_station.append(station_rows.set_index("Datetime"))

    return per_station


def power_transform(dfs: List[pd.DataFrame]) -> List[pd.DataFrame]:
    """Apply a Yeo-Johnson power transform to each frame's non-``station`` columns.

    The transformer is re-fitted on every frame. The ``station`` column is
    left out of the transform and re-attached to the result afterwards.

    Args:
        dfs: frames that each contain a ``station`` column.

    Returns:
        A new list of transformed frames with the same index as the inputs.
    """
    transformer = PowerTransformer(method="yeo-johnson")

    def _transform_one(frame: pd.DataFrame) -> pd.DataFrame:
        features = frame.drop(["station"], axis=1)
        result = pd.DataFrame(
            transformer.fit_transform(features),
            columns=features.columns,
            index=features.index,
        )
        result["station"] = frame["station"]
        return result

    return [_transform_one(frame) for frame in dfs]


def save(dfs: List[pd.DataFrame]):
    """Write each frame to ``<save_path>/<station>.csv`` with 5-decimal floats.

    The output directory is created if it does not exist, and the file name is
    taken from the first value of the frame's ``station`` column.
    """
    os.makedirs(save_path, exist_ok=True)
    for frame in dfs:
        station_name = frame['station'].iloc[0]
        destination = f"{save_path}/{station_name}.csv"
        frame.to_csv(destination, float_format="%.5f")


# Run the pipeline only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
